# From Offer-Westort et al (2021) replication materials
#' Format numbers as fixed-point strings.
#'
#' @param x Values to format; coerced with `as.numeric()` first.
#' @param digits Number of digits printed after the decimal point.
#' @return A character vector with one formatted string per element of `x`.
format_num <- function(x, digits = 2) {
  # Build the sprintf template once, e.g. "%.2f" for digits = 2.
  fixed_fmt <- paste0("%.", digits, "f")
  sprintf(fixed_fmt, as.numeric(x))
}
# From Offer-Westort et al (2021) replication materials
#' Format numbers with `format_num()` and wrap each result in parentheses.
#'
#' @param x Values to format; coerced with `as.numeric()` first.
#' @param digits Number of digits printed after the decimal point.
#' @return A character vector such as `"(1.23)"`, one entry per element of `x`.
add_parens <- function(x, digits = 2) {
  inner <- format_num(as.numeric(x), digits = digits)
  paste0("(", inner, ")")
}
# From Offer-Westort et al (2021) replication materials
#' Simulate one batched multi-arm trial with binary outcomes.
#'
#' The first batch is assigned uniformly across arms. Each later batch is
#' assigned uniformly when `static = TRUE`, otherwise with probabilities taken
#' from the previous period's row of `ppmat`.
#'
#' @param n Batch size. When `first` is supplied, the later batches are resized
#'   to `(n * periods - first) / (periods - 1)` so the total stays `n * periods`.
#' @param periods Number of batches.
#' @param arms Number of arms.
#' @param probs Success probability of each arm, used for the binomial draws.
#' @param static If `TRUE`, every batch uses uniform assignment.
#' @param control Not used in the body.
#' @param first Size of the first batch; `NA` means it equals `n`.
#' @return A list of three `periods` x `arms` matrices: `nmat` (cumulative
#'   assignments), `xmat` (cumulative successes) and `ppmat` (the values
#'   returned by `best_binomial_bandit_sim()` after each period).
sim_out <- function(n = 100, periods = 10, arms = 9, probs, static = FALSE,
                    control = FALSE, first = NA) {
  # Resolve batch sizes: all equal to `n`, or a custom first batch followed by
  # equal batches that keep the overall total unchanged.
  if (is.na(first)) {
    first <- n
  } else {
    n <- (n * periods - first) / (periods - 1)
  }

  ppmat <- matrix(NA, ncol = arms, nrow = periods)
  nmat <- ppmat
  xmat <- ppmat

  uniform <- rep(1 / arms, arms)

  # One binomial draw per arm: successes among the `sizes` newly assigned units.
  draw_successes <- function(sizes) {
    mapply(rbinom, n = 1, size = sizes, prob = probs)
  }

  # Period 1: uniform assignment of the first batch.
  nmat[1, ] <- table(simple_ra(N = first, prob_each = uniform))
  xmat[1, ] <- draw_successes(nmat[1, ])
  ppmat[1, ] <- best_binomial_bandit_sim(xmat[1, ], nmat[1, ])

  # Later periods: assign a new batch, then accumulate counts and re-score.
  period <- 1
  while (period < periods) {
    period <- period + 1
    assign_probs <- if (static == TRUE) uniform else ppmat[period - 1, ]
    batch <- table(simple_ra(N = n, prob_each = assign_probs))
    nmat[period, ] <- nmat[period - 1, ] + batch
    xmat[period, ] <- xmat[period - 1, ] + draw_successes(batch)
    ppmat[period, ] <- best_binomial_bandit_sim(xmat[period, ], nmat[period, ])
  }

  list(nmat = nmat, xmat = xmat, ppmat = ppmat)
}
# From Offer-Westort et al (2021) replication materials
#' Fit a weighted regression of unit-level outcomes on assigned arm.
#'
#' Expands the cumulative count matrices in `vals` into one 0/1 outcome per
#' assigned unit and fits `lm_robust()` with weights `1 / wvals`. Without a
#' control arm the model has no intercept (one coefficient per arm); with a
#' control arm that arm becomes the reference level of `zvals`.
#'
#' @param vals List with cumulative `nmat` and `xmat` matrices (periods x
#'   arms). Adaptive fits also read `ppmat` (no control) or `spmat` (with
#'   control) for the per-period assignment probabilities. An optional
#'   `control` element gives the label of the reference arm.
#' @param static `TRUE` uses a constant weight for every unit; `FALSE` weights
#'   each unit by the inverse of the probability recorded for its period/arm.
#' @param se_type Standard-error type passed to `lm_robust()`.
#' @return The object returned by `lm_robust()`.
ipw_est <- function(vals, static = FALSE, se_type = 'HC2') {

  # Per-period (non-cumulative) counts in long format: one row per period x arm
  bbmat <- melt(rbind(vals$nmat[1, ],
                      diff(vals$nmat)))
  bbmat$success <- melt(rbind(vals$xmat[1, ],
                              diff(vals$xmat)))$value

  # Expand the counts to one record per unit. The names yvals / zvals / wvals
  # are part of the model formula and of the coefficient names; keep them.
  yvals <- as.vector(unlist(apply(bbmat, 1, function(x)
    c(rep(0, x['value'] - x['success']), rep(1, x['success'])))))

  zvals <- as.factor(unlist(apply(bbmat, 1, function(x)
    c(rep(x['Var2'], x['value'])))))

  has_control <- length(vals$control) > 0

  if ((static == TRUE) && !has_control) { # static, no control

    wvals <- rep(1 / length(unique(zvals)), length(yvals))
    lmfit <- lm_robust(yvals ~ -1 + zvals, weights = 1 / wvals,
                       se_type = se_type)

  } else if ((static == TRUE) && has_control) { # static with control

    wvals <- rep(1 / length(unique(zvals)), length(yvals))
    zvals <- relevel(zvals, ref = as.character(vals$control))
    lmfit <- lm_robust(yvals ~ zvals, weights = 1 / wvals, se_type = se_type)

  } else if ((static == FALSE) && !has_control) { # adaptive, no control

    # Period 1 was assigned uniformly; period t > 1 used row t - 1 of ppmat.
    # seq_len() keeps this valid for a single-period trial, where 1:0 would
    # wrongly pick up row 1.
    prior_rows <- vals$ppmat[seq_len(nrow(vals$ppmat) - 1), , drop = FALSE]
    bbmat$weight <- melt(rbind(rep(1 / ncol(vals$ppmat), ncol(vals$ppmat)),
                               prior_rows))$value
    wvals <- unlist(apply(bbmat, 1, function(x)
      c(rep(x['weight'], x['value']))))

    lmfit <- lm_robust(yvals ~ -1 + zvals, weights = 1 / wvals,
                       se_type = se_type)

  } else { # adaptive with control

    # Match the control arm by label, as the static branch does. A bare number
    # is read by relevel() as a position in levels(zvals), which picks the
    # wrong arm whenever some arm has no observations.
    zvals <- relevel(zvals, ref = as.character(vals$control))
    bbmat$weight <- melt(vals$spmat)$value
    wvals <- unlist(apply(bbmat, 1, function(x)
      c(rep(x['weight'], x['value']))))

    lmfit <- lm_robust(yvals ~ zvals, weights = 1 / wvals, se_type = se_type)
  }
  return(lmfit)
}

#' Run repeated simulated trials and collect per-iteration diagnostics.
#'
#' Each iteration simulates one trial (`sim_out()`, `sim_out_te()` or
#' `sim_out_RI()`, depending on `control` and `RI`), fits `ipw_est()` and
#' stores summary statistics for the arm with the largest entry of `probs`.
#'
#' @param probs Success probabilities, one per arm. Replaced by
#'   `rep(1 / mean(Y), K)` when `RI = TRUE`.
#' @param control `FALSE` for no designated control arm; any other value is
#'   passed to the simulator as the control arm and used to index `probs`.
#' @param static Passed to the simulators and to `ipw_est()`.
#' @param periods Number of batches per simulated trial.
#' @param n Batch size passed to `sim_out()` / `sim_out_te()`.
#' @param iter Number of simulated trials; `0` returns an empty result.
#' @param RI If `TRUE`, trials are generated by `sim_out_RI()` from `Y` and `Z`.
#' @param Y,Z Inputs for `sim_out_RI()`; `length(Z)` sets the number of arms
#'   when `RI = TRUE`.
#' @param first First-batch size passed to `sim_out()` / `sim_out_te()`.
#' @param ppmat If `TRUE`, also stack each iteration's `ppmat` (only in the
#'   non-RI, no-control path).
#' @return A list with `outmat` (one row per iteration), `d_fit` (stacked
#'   tidied fits joined with the final-period posterior) and `d_ppmat`
#'   (stacked posterior matrices, or `NULL` when not collected).
simulate <- function(probs = NA, control = FALSE, static = FALSE,
                     periods = 10, n = 200, iter = 1000,
                     RI = FALSE, Y = NA, Z = NA, first = NA, ppmat = FALSE){
  # 10 named statistics plus one column per period 0..periods
  outmat <- matrix(rep(NA, (11+periods)*iter), ncol = (11+periods))
  d_fit <- d_ppmat <- NULL
  colnames(outmat) <- c('correct', 'posterior_best', 'bias', 'rmse_best',
                        'rmse_te', 'cov_best', 'cov_te', 'ses', 'F', 'p',
                        0:periods)

  K <- if (RI) length(Z) else length(probs)
  if(RI) { probs <- rep(1/mean(Y), K) }

  # Period 0 column holds the uniform value 1/K
  outmat[,'0'] <- 1/K

  cat('iteration: ')
  # seq_len() so that iter = 0 runs no iterations (1:0 would run two)
  for( i in seq_len(iter)){

    if(control==FALSE){ # no augmented control
      if(RI){
        vals <- sim_out_RI(Y, Z, periods = periods, static = static)
        true_val <- mean(Y)
        true_vali <- 'zvals1'
      } else {
        vals <- sim_out(probs = probs, arms = K, static = static,
                        periods = periods, n = n, first = first)
        true_val <- max(probs)
        true_vali <- paste0('zvals', which.max(probs)) # index of max probs
        if (ppmat == TRUE){
          d_ppmat <- bind_rows(
            d_ppmat,
            data.frame(vals$ppmat) %>%
              mutate(period = 2:(1+nrow(vals$ppmat)), iter = i))
        }
      }
      lm_r <- ipw_est(vals, static)

      # rmse
      outmat[i,'rmse_best'] <- sqrt((coef(lm_r)[true_vali]-true_val)^2)
      # coverage
      outmat[i, 'cov_best'] <- 1 * ((lm_r[['conf.low']][true_vali] < true_val) &
                                      (lm_r[['conf.high']][true_vali] > true_val))
    } else {
      if(RI){
        vals <- sim_out_RI(Y, Z, periods = periods, static = static,
                           control = control)
        true_val <- 0
        true_vali <- 'zvals1'
      } else {
        vals <- sim_out_te(probs = probs, arms = K, static = static,
                           control = control, periods = periods, n = n,
                           first = first)
        true_val <- max(probs)-probs[control]
        true_vali <- paste0('zvals', which.max(probs)) # index of max probs
      }

      lm_r <- ipw_est(vals, static)
      # NOTE(review): the fit is appended here and again after this if/else,
      # so this branch adds two sets of rows per iteration -- confirm intended.
      d_fit <- bind_rows(d_fit, tidy(lm_r) %>% mutate(iter = i))
      # Best-arm mean = intercept (control mean) + best-arm contrast
      est <- sum(coef(lm_r)[c('(Intercept)', true_vali)])
      se <- sqrt(abs(sum(vcov(lm_r)[c('(Intercept)', true_vali),
                                    c('(Intercept)', true_vali)])))

      outmat[i,'rmse_best'] <- sqrt((est-max(probs))^2)

      cint <- est + se*qt(0.975,lm_r$N-lm_r$k)*c(-1,1)

      outmat[i, 'cov_best'] <- 1 * ((cint[1] < max(probs)) &
                                       (cint[2] > max(probs)))
    }
    # One row per arm (arms missing from the fit are added by complete()),
    # ordered by arm number and paired with the final-period posterior
    d_fit <- bind_rows(d_fit, tidy(lm_r) %>%
                         complete(term = paste0('zvals', 1:K)) %>%
                         mutate(ord = as.numeric(gsub('zvals', '', term))) %>%
                         arrange(ord) %>%
                         bind_cols(posterior = vals$ppmat[nrow(vals$ppmat),]) %>%
                         mutate(iter = i) %>%
                         select(-ord))
    # selected correct arm
    outmat[i,'correct'] <- (which.max(vals$ppmat[nrow(vals$ppmat),])==which.max(
      probs))*1
    # posterior probability of best arm
    outmat[i,'posterior_best'] <- vals$ppmat[nrow(vals$ppmat), which.max(probs)]
    # bias
    outmat[i,'bias'] <- coef(lm_r)[true_vali]-true_val
    # ses
    outmat[i,'ses'] <- lm_r$std.error[true_vali]

    if(control == FALSE) {# select arm 3 as control arm for statistics
      # NOTE(review): arm 3 is hard-coded, so this path needs at least 3 arms.
      vals$control <- 3
      # Assignment probabilities per period: uniform first, then lagged ppmat
      vals$spmat <- rbind(rep(1/K, K),
                          vals$ppmat)[1:periods,]
      # re-estimate lm_r with intercept
      lm_r <- ipw_est(vals, static)
      true_val <- max(probs)-probs[vals$control]
    }
    # rmse of ATE
    outmat[i,'rmse_te'] <- sqrt((coef(lm_r)[true_vali]-true_val)^2)
    # coverage of ATE
    outmat[i,'cov_te'] <- 1 * ((lm_r[['conf.low']][true_vali] < true_val) &
                                 (lm_r[['conf.high']][true_vali] > true_val))
    # A missing coefficient (NA coverage) is counted as not covered
    if(is.na(outmat[i,'cov_te'])){
      outmat[i,'cov_te'] <- 0
    }

    # Pr(>|t|)
    outmat[i,'p'] <- lm_r$p.value[true_vali] # of true best arm

    # re-estimate lm_r with classical SEs
    lm_r <- ipw_est(vals, static, se_type = 'classical')

    # F
    outmat[i,'F'] <- summary(lm_r)[['fstatistic']]['value']

    # Per-period posterior of the arm with the largest entry of probs
    outmat[i, as.character(1:periods)] <- vals$ppmat[,which.max(probs)]

    cat(i, '...')
  }
  cat('\n')
  return(list(outmat = outmat, d_fit = d_fit, d_ppmat = d_ppmat))
}

#' Simulate a two-stage trial: screen all arms, then re-run the top arms.
#'
#' Stage 1 assigns `n` units uniformly over all `arms` and scores the arms
#' with `best_binomial_bandit_sim()`. Stage 2 assigns a fresh `n` units
#' uniformly over the `ntop` highest-scoring arms. Only the stage-2 counts
#' are returned.
#'
#' @param n Number of units assigned in each stage.
#' @param arms Number of arms in stage 1.
#' @param probs Success probability of each arm.
#' @param ntop Number of arms carried into stage 2.
#' @return A list of 1 x `ntop` matrices: `nmat` (assignments), `xmat`
#'   (successes), `ppmat` (stage-2 scores) and `topmat` (the original indices
#'   of the selected arms, best first).
sim_out_hybrid <- function(n = 1000, arms = 11, probs, ntop = 5) {
  # Stage 1: uniform assignment over every arm, then score the arms.
  screen_n <- screen_x <- screen_pp <- matrix(NA, ncol = arms, nrow = 1)
  screen_n[1, ] <- table(simple_ra(N = n, prob_each = rep(1 / arms, arms)))
  screen_x[1, ] <- mapply(rbinom, n = 1, size = screen_n[1, ], prob = probs)
  screen_pp[1, ] <- best_binomial_bandit_sim(screen_x[1, ], screen_n[1, ])

  # Stage 2: keep the ntop best-scoring arms and run a static design on them.
  top <- order(screen_pp[1, ], decreasing = TRUE)[1:ntop]
  blank <- matrix(NA, ncol = ntop, nrow = 1)
  nmat <- blank
  xmat <- blank
  ppmat <- blank
  topmat <- blank
  top_probs <- probs[top]

  nmat[1, ] <- table(simple_ra(N = n, prob_each = rep(1 / ntop, ntop)))
  xmat[1, ] <- mapply(rbinom, n = 1, size = nmat[1, ], prob = top_probs)
  ppmat[1, ] <- best_binomial_bandit_sim(xmat[1, ], nmat[1, ])
  topmat[1, ] <- top

  list(nmat = nmat, xmat = xmat, ppmat = ppmat, topmat = topmat)
}

#' Run repeated two-stage (hybrid) trials and stack the stage-2 fits.
#'
#' Each iteration calls `sim_out_hybrid()`, expands the stage-2 counts to one
#' record per unit and fits an equal-weight, no-intercept `lm_robust()` model
#' with HC2 standard errors on the selected arms.
#'
#' @param probs Success probabilities, one per arm.
#' @param periods Only used to size `outmat`.
#' @param n Units per stage, passed to `sim_out_hybrid()`.
#' @param iter Number of simulated trials; `0` returns an empty result.
#' @param ntop Number of arms carried into stage 2.
#' @return A list with `outmat` and `d_fit`. `d_fit` stacks the tidied fits,
#'   joined with each selected arm's stage-2 `posterior` and tagged with
#'   `iter` (`NULL` when no iteration ran). In `outmat` only column `'0'`
#'   (set to `1 / K`) is filled; the loop never writes the other columns.
simulate_hybrid <- function(probs = NA, periods = 1, n = 1000, iter = 1000,
                            ntop = 5) {
  outmat <- matrix(rep(NA, (11 + periods) * iter), ncol = (11 + periods))
  d_fit <- NULL
  colnames(outmat) <- c('correct', 'posterior_best', 'bias', 'rmse_best',
                        'rmse_te', 'cov_best', 'cov_te', 'ses', 'F', 'p',
                        0:periods)
  K <- length(probs)
  outmat[, '0'] <- 1 / K

  cat('iteration: ')
  # seq_len() so that iter = 0 runs no iterations (1:0 would run two)
  for (i in seq_len(iter)) {
    vals <- sim_out_hybrid(probs = probs, arms = K, n = n, ntop = ntop)

    # Long-format stage-2 counts; Var2 is relabelled with the original arm ids
    bbmat <- melt(rbind(vals$nmat[1, ], diff(vals$nmat)))
    bbmat$Var2 <- vals$topmat[1, ]
    bbmat$success <- melt(rbind(vals$xmat[1, ], diff(vals$xmat)))$value

    # One record per unit. The names yvals / zvals / wvals are part of the
    # model formula and of the coefficient names; keep them.
    yvals <- as.vector(unlist(apply(bbmat, 1, function(x)
      c(rep(0, x['value'] - x['success']), rep(1, x['success'])))))
    zvals <- as.factor(unlist(apply(bbmat, 1, function(x)
      c(rep(x['Var2'], x['value'])))))
    wvals <- rep(1 / length(unique(zvals)), length(yvals))
    lm_r <- lm_robust(yvals ~ -1 + zvals, weights = 1 / wvals,
                      se_type = 'HC2')

    posterior <- data.frame(posterior = vals$ppmat[1, ],
                            term = paste0('zvals', vals$topmat[1, ]))
    d_fit <- bind_rows(d_fit,
                       left_join(tidy(lm_r), posterior, by = 'term') %>%
                         mutate(iter = i))
    cat(i, '...')
  }
  cat('\n')
  return(list(outmat = outmat, d_fit = d_fit))
}

#' Summarise simulated per-arm fits against benchmark estimates.
#'
#' NOTE(review): reads a fitted model `mod` that is not an argument; it is
#' resolved from the enclosing environment (presumably the global workspace)
#' and its terms are expected to be named `arm<k>` -- confirm with callers.
#'
#' @param data Stacked per-iteration fits with columns `term`, `estimate`,
#'   `std.error`, `conf.low`, `conf.high`, `posterior` and `iter`.
#' @return One row per arm, sorted by decreasing benchmark estimate. Numeric
#'   summaries are formatted to three decimals; `true` and `est` carry their
#'   standard error in parentheses. `bias` is benchmark minus estimate.
result_sum <- function(data){
  # Benchmark estimate and SE per arm, renamed to match the simulated terms
  truth <- tidy(mod)[,1:3] %>%
    transmute(term = gsub('arm', 'zvals', term),
              true = estimate,
              true_se = std.error)

  est <- data %>%
    # Explicit key: do not rely on whichever columns happen to be shared
    left_join(truth, by = 'term') %>%
    group_by(iter) %>%
    # Within each iteration, flag the arm with the highest posterior
    mutate(best = max(posterior) == posterior) %>%
    group_by(term) %>%
    summarize(true = mean(true),
              true_se = mean(true_se),
              best = mean(best),
              est  = mean(estimate),
              se   = mean(std.error),
              bias = mean(true - estimate),
              rmse = sqrt(mean((estimate - true)^2)),
              coverage = mean(conf.low < true & conf.high > true)) %>%
    arrange(-true) %>%
    mutate_if(is.numeric, ~sprintf('%0.3f', .x)) %>%
    mutate(term = gsub('zvals', 'Arm ', term),
           true = paste0(true, ' (', true_se, ')'),
           est = paste0(est, ' (', se, ')')) %>%
    dplyr::select(-true_se, -se)
  return(est)
}

#' Summarise simulated per-arm fits against benchmark estimates, skipping NAs.
#'
#' Same summary as `result_sum()`, but every per-arm aggregate drops missing
#' values (`na.rm = TRUE`).
#'
#' NOTE(review): reads a fitted model `mod` that is not an argument; it is
#' resolved from the enclosing environment (presumably the global workspace)
#' and its terms are expected to be named `arm<k>` -- confirm with callers.
#' NOTE(review): `max(posterior)` inside each iteration is computed without
#' `na.rm`, so an iteration with a missing posterior yields `best = NA` for
#' all of its rows -- confirm this is intended.
#'
#' @param data Stacked per-iteration fits with columns `term`, `estimate`,
#'   `std.error`, `conf.low`, `conf.high`, `posterior` and `iter`.
#' @return One row per arm, sorted by decreasing benchmark estimate. Numeric
#'   summaries are formatted to three decimals; `true` and `est` carry their
#'   standard error in parentheses. `bias` is benchmark minus estimate.
result_sum_NA <- function(data){
  # Benchmark estimate and SE per arm, renamed to match the simulated terms
  truth <- tidy(mod)[,1:3] %>%
    transmute(term = gsub('arm', 'zvals', term),
              true = estimate,
              true_se = std.error)

  est <- data %>%
    # Explicit key: do not rely on whichever columns happen to be shared
    left_join(truth, by = 'term') %>%
    group_by(iter) %>%
    # Within each iteration, flag the arm with the highest posterior
    mutate(best = max(posterior) == posterior) %>%
    group_by(term) %>%
    # TRUE spelled out: `T` is an ordinary variable that can be reassigned
    summarize(true = mean(true, na.rm = TRUE),
              true_se = mean(true_se, na.rm = TRUE),
              best = mean(best, na.rm = TRUE),
              est  = mean(estimate, na.rm = TRUE),
              se   = mean(std.error, na.rm = TRUE),
              bias = mean(true - estimate, na.rm = TRUE),
              rmse = sqrt(mean((estimate - true)^2, na.rm = TRUE)),
              coverage = mean(conf.low < true & conf.high > true,
                              na.rm = TRUE)) %>%
    arrange(-true) %>%
    mutate_if(is.numeric, ~sprintf('%0.3f', .x)) %>%
    mutate(term = gsub('zvals', 'Arm ', term),
           true = paste0(true, ' (', true_se, ')'),
           est = paste0(est, ' (', se, ')')) %>%
    dplyr::select(-true_se, -se)
  return(est)
}

#' Summarise hybrid-design fits against benchmark estimates.
#'
#' Like `result_sum()`, but an arm only has a row in the iterations where it
#' appears in `data`, so `best` and `top5` are reported as shares of all
#' `iter` iterations rather than as means over the arm's own rows.
#'
#' NOTE(review): reads a fitted model `mod` that is not an argument; it is
#' resolved from the enclosing environment (presumably the global workspace)
#' and its terms are expected to be named `arm<k>` -- confirm with callers.
#'
#' @param data Stacked per-iteration fits with columns `term`, `estimate`,
#'   `std.error`, `conf.low`, `conf.high`, `posterior` and `iter`.
#' @param iter Total number of iterations, used as the denominator for `best`
#'   and `top5`.
#' @return One row per arm, sorted by decreasing benchmark estimate. Numeric
#'   summaries are formatted to three decimals; `true` and `est` carry their
#'   standard error in parentheses. `bias` is benchmark minus estimate.
result_hybrid <- function(data, iter){
  # Benchmark estimate and SE per arm, renamed to match the simulated terms
  truth <- tidy(mod)[,1:3] %>%
    transmute(term = gsub('arm', 'zvals', term),
              true = estimate,
              true_se = std.error)

  est <- data %>%
    # Explicit key: do not rely on whichever columns happen to be shared
    left_join(truth, by = 'term') %>%
    # Here `iter` is the data column (data columns mask the argument)
    group_by(iter) %>%
    # Within each iteration, flag the arm with the highest posterior
    mutate(best = max(posterior) == posterior) %>%
    group_by(term) %>%
    summarize(true = mean(true),
              true_se = mean(true_se),
              best = sum(best),
              top5 = n(),
              est  = mean(estimate),
              se   = mean(std.error),
              bias = mean(true - estimate),
              rmse = sqrt(mean((estimate - true)^2)),
              coverage = mean(conf.low < true & conf.high > true)) %>%
    arrange(-true) %>%
    # No `iter` column is left after summarize(), so this is the argument
    mutate(top5 = top5 / iter,
           best = best / iter) %>%
    mutate_if(is.numeric, ~sprintf('%0.3f', .x)) %>%
    mutate(term = gsub('zvals', 'Arm ', term),
           true = paste0(true, ' (', true_se, ')'),
           est = paste0(est, ' (', se, ')')) %>%
    dplyr::select(-true_se, -se)
  return(est)
}
